import datetime
import json
from collections import Counter

import akshare as ak
import matplotlib.dates as mdate
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from python_quant.stock.myTT import HHV, LLV


def max_drawdown(ycapital):
    # 计算每日的回撤
    drawdown = []
    tmp_max_capital = ycapital[0]
    for c in ycapital:
        tmp_max_capital = max(c, tmp_max_capital)
        drawdown.append(1 - c / tmp_max_capital)

    # 最大回撤
    maxDrawdown = max(drawdown)
    # 计算最大回撤日期范围
    endidx = np.argmax(drawdown)
    # enddate = xdate[endidx]

    startidx = np.argmax(ycapital[:endidx])
    # startdate = xdate[startidx]
    # 仅仅画图的话，我们只要索引值更加方便
    return maxDrawdown, startidx, endidx


def max_drawdown_duration(xdate, ycapital):
    duration = []
    tmp_max_capital = ycapital[0]
    for c in ycapital:
        if c >= tmp_max_capital:
            duration.append(0)
        else:
            duration.append(duration[-1] + 1)
        tmp_max_capital = max(c, tmp_max_capital)

    MaxDDD = max(duration)

    # fig, ax = plt.subplots(figsize = (21, 9))
    # plt.plot(xdate, duration)
    # ax.xaxis.set_major_formatter(mdate.DateFormatter('%Y-%m-%d'))
    endidx = np.argmax(duration)
    startidx = endidx - MaxDDD

    return MaxDDD, startidx, endidx


def max_drawdown_restore_time(startidx, endidx, xdate, ycapital):
    """
        startidx:表示最大回撤的开始时间在 xdate 中的索引，由 max_drawdown 方法返回
        endidx:表示最大回撤的结束时间在 xdate 中的索引，由 max_drawdown 方法返回
    """
    maxdd_resore_time = 0
    restore_endidx = np.inf
    for t in range(endidx, len(xdate)):
        if ycapital[t] >= ycapital[startidx]:
            restore_endidx = t
            break
        else:
            maxdd_resore_time += 1

    restore_endidx = min(restore_endidx, len(xdate) - 1)
    return maxdd_resore_time, restore_endidx


def plot(xdate, ycapital):
    # 指定画布大小
    fig, ax = plt.subplots(figsize=(21, 9))
    # 绘图并设置颜色，图例标签，线宽
    plt.plot(xdate, ycapital, 'red', label='My Strategy', linewidth=2)

    # 绘制最大回撤日期范围标识 marker = 'v'
    MaxDrawdown, startidx, endidx = max_drawdown(xdate, ycapital)
    print("最大回撤为：", MaxDrawdown)
    plt.scatter([xdate[startidx], xdate[endidx]], [ycapital[startidx], ycapital[endidx]],
                s=100, c='b', marker='s', label='MaxDrawdown')
    # 绘制最大回撤恢复时间
    maxdd_resore_time, restore_endidx = max_drawdown_restore_time(startidx, endidx, xdate, ycapital)
    print("最大回撤恢复时间为（天）：", maxdd_resore_time)
    plt.scatter([xdate[endidx], xdate[restore_endidx]], [ycapital[endidx], ycapital[restore_endidx]],
                s=100, c='cyan', marker='D', label='MaxDrawdown Restore Time')

    # 绘制最大回撤持续期标识 marker = 'D'
    MaxDDD, startidx, endidx = max_drawdown_duration(xdate, ycapital)
    plt.scatter([xdate[startidx], xdate[endidx]], [ycapital[startidx], ycapital[endidx]],
                s=80, c='g', marker='v', label='MaxDrawdown Duration')
    print("最大回撤持续期为（天）：", MaxDDD)
    # 设置刻度值颜色
    plt.yticks(color='gray')
    # 设置 y 轴百分比显示，注意将 y 轴数据乘以 100
    # ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f%%'))
    # 颜色，旋转刻度
    plt.xticks(color='gray', rotation=15)
    # 指定字体，大小，颜色
    fontdict = {"family": "Times New Roman", 'size': 12, 'color': 'gray'}  # Times New Roman, Arial
    plt.title("random account value", fontdict=fontdict)
    plt.xlabel("date(day)", fontdict=fontdict)
    plt.ylabel("account value", fontdict=fontdict)
    # 去掉边框 top left right bottom
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # 设置 x 轴颜色
    ax.spines['bottom'].set_color('lightgray')
    # 设置时间标签显示格式
    ax.xaxis.set_major_formatter(mdate.DateFormatter('%Y-%m-%d'))
    # 设置时间刻度间隔
    # timedelta = (xdate[-1] - xdate[0]) / 10  # 这种方式不能保证显示最后一个日期
    # plt.xticks(mdate.drange(xdate[0], xdate[-1], timedelta))
    # 分成 10 份
    delta = round(len(xdate) / 9)
    plt.xticks([xdate[i * delta] for i in range(9)] + [xdate[-1]])
    # 通过修改tick_spacing的值可以修改x轴的密度
    # ax.xaxis.set_major_locator(ticker.MultipleLocator(10))
    # 去掉 y 轴刻度线,四个方向均可设置
    plt.tick_params(left='off')
    # 设置刻度的朝向，宽，长度
    plt.tick_params(which='major', direction='out', width=0.2, length=5)  # in, out or inout
    # 设置刻度显示在哪个方向上
    # tick_params(labeltop='on',labelbottom='off',labelleft='off',labelright='off')
    # 设置 y 轴方向的网络线
    plt.grid(axis='y', color='lightgray', linestyle='-', linewidth=0.5)

    # 设置图例 列宽：columnspacing=float (upper left)
    plt.legend(loc='best', fontsize=12, frameon=False, ncol=1)
    # 设置图例字体颜色
    # leg = 上一行 plt.legend 的返回值
    # for line,text in zip(leg.legendHandles, leg.get_texts()):
    #    text.set_color(line.get_color())

    fig.show()
    # fig.savefig("test.png")  # dpi = 150


def analysis_stock_trade(symbol, signal_date_list, keep_days, tp_percent, sl_percent):
    analysis_list = list()
    for signal_date in signal_date_list:
        # 下单日期
        trade_date = (signal_date + datetime.timedelta(days=1)).strftime('%Y%m%d')

        # k线周期结束日期
        now_date = datetime.datetime.now().strftime('%Y%m%d')

        stock_zh_a_hist_df = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=trade_date,
                                                end_date=now_date, adjust="hfq")
        print(symbol, trade_date, len(stock_zh_a_hist_df))
        if stock_zh_a_hist_df.empty:
            analysis_dict = {
                "代码": symbol,
                "信号日": signal_date,
                "买入日期": np.nan
            }
            analysis_list.append(analysis_dict)
            continue

        if len(stock_zh_a_hist_df) == 1:
            analysis_dict = {
                "代码": symbol,
                "信号日": signal_date,
                "买入日期": trade_date
            }
            analysis_list.append(analysis_dict)
            continue

        # 取21个交易日的k线数据
        stock_zh_a_hist_cut_df = stock_zh_a_hist_df.iloc[0:1 + keep_days]

        # 交易周期结束日期
        trade_end_date = stock_zh_a_hist_cut_df.iloc[-1]["日期"]

        # 取收盘价序列
        C = stock_zh_a_hist_cut_df['收盘'].values

        # 取开盘价序列
        O = stock_zh_a_hist_cut_df['开盘'].values

        # 取最高价序列
        H = stock_zh_a_hist_cut_df['最高'].values

        # 取最低价序列
        L = stock_zh_a_hist_cut_df['最低'].values

        # 取时间序列
        DATE = stock_zh_a_hist_cut_df['日期'].values

        # 信号日后的第一根k线的开盘价作为买入价格
        buy_price = O[0]

        # 取买入k线后的20根k线，作为统计k线
        A_O = O[1:]
        A_C = C[1:]
        A_H = H[1:]
        A_L = L[1:]
        A_DATE = DATE[1:]

        # 每根k线的累计涨幅
        up_rate_acc = np.around((A_H - buy_price) / buy_price * 100, 2)
        # 每根k线上的累计跌幅
        down_rate_acc = np.around((A_L - buy_price) / buy_price * 100, 2)

        # 寻找区间最高价
        max_high = HHV(A_H, len(A_H))[-1]
        # 寻找区间最低价
        min_low = LLV(A_L, len(A_L))[-1]

        # 最大涨幅对应的日期
        max_high_index = np.argmax(A_H)
        max_high_date = A_DATE[max_high_index]
        # 最大跌幅对应的日期
        min_low_index = np.argmin(A_L)
        min_low_date = A_DATE[min_low_index]
        # 是否先达到最大涨幅后达到最大跌幅
        up_first_down_after = True if max_high_date < min_low_date else False
        # 离场日期
        close_date = trade_end_date

        # 首次达到止盈对应的日期
        first_tp_date = None
        first_tp_up_rate = None

        tp_index = np.argwhere(up_rate_acc >= tp_percent)
        if len(tp_index) > 0:
            first_tp_date = A_DATE[tp_index[0][0]]
            # 止盈的实际止盈比例
            first_tp_up_rate = up_rate_acc[tp_index[0][0]]

        # 首次达到止损对应的日期
        first_sl_date = None
        first_sl_down_rate = None

        sl_index = np.argwhere(down_rate_acc <= sl_percent)
        if len(sl_index) > 0:
            first_sl_date = A_DATE[sl_index[0][0]]
            # 实际止损比例
            first_sl_down_rate = down_rate_acc[sl_index[0][0]]

        # 成功失败标记，打止盈或者周期结束为盈利视为成功，打止损或者周期结束为亏损视为失败

        is_success_trade = False
        # 打止盈
        condition1 = first_tp_date is not None and first_sl_date is None
        condition2 = first_tp_date is not None and first_sl_date is not None and first_tp_date < first_sl_date
        condition3 = first_sl_date is not None and first_sl_down_rate > 0
        if condition1 or condition2 or condition3:
            close_date = first_tp_date
            if condition3:
                close_date = first_sl_date
            is_success_trade = True

        if first_sl_date is not None and (
                (first_tp_date is not None and first_tp_date > first_sl_date) or first_tp_date is None):
            close_date = first_sl_date
            is_success_trade = False

        analysis_dict = {
            "代码": symbol,
            "信号日": signal_date,
            "买入日期": trade_date,
            "离场日期": close_date,
            "交易周期结束日期": trade_end_date,
            "开仓价/开盘价": buy_price,
            "是否是成功的交易": is_success_trade,
            "是否先触及最大涨幅再触及最大跌幅": up_first_down_after,
            "首次达到止盈对应的日期": first_tp_date,
            "实际的止盈比例": first_tp_up_rate,
            "最大涨幅": max(up_rate_acc),
            "最大涨幅对应日期": max_high_date,
            "首次达到止损对应的日期": first_sl_date,
            "实际的止损比例": first_sl_down_rate,
            "最大跌幅": min(down_rate_acc),
            "最大跌幅对应的日期": min_low_date,
            "最高价": max_high,
            "最低价": min_low
        }

        analysis_list.append(analysis_dict)

    analysis_df = pd.DataFrame(analysis_list)
    if analysis_df.empty:
        return analysis_df
    analysis_df = analysis_df.sort_values(by="买入日期")
    # analysis_df.style.format("{:.2f}")
    return analysis_df


def calc_field_value_times(data_pd, field, value):
    """
    计算连续数据
    :param data_pd: 要处理的pandas数据集
    :param field: 要计算的字段
    :param value: 值
    :return:
    """
    # 判断值是否存在
    if data_pd.query("%s == %s" % (field, value)).empty:
        return 0

    data_pd["subgroup"] = data_pd[field].ne(data_pd[field].shift()).cumsum()
    return data_pd.groupby([field, "subgroup"]).apply(len)[value].max()


def analyze_df(analysis_df, code):
    # analysis_filter_df = analysis_df[analysis_df["是否是成功的交易"].notnull()]

    total_size = len(analysis_df)
    success_trade_size = len(analysis_df[(analysis_df['是否是成功的交易'] == True)])
    ana_dict = {'代码': code, '总交易次数': total_size, '成功交易次数': success_trade_size}

    if ana_dict['总交易次数'] > 0:
        ana_dict['胜率'] = ana_dict['成功交易次数'] / ana_dict['总交易次数'] * 100
    cur_time, max_time, loss_times, count_loss_times = analysis_keep_loss_times(analysis_df['是否是成功的交易'])
    ana_dict['最新连损次数'] = cur_time
    ana_dict['最大连损次数'] = max_time
    ana_dict['剩余连损次数'] = max_time - cur_time
    loss_times = ','.join(str(item) for item in loss_times)
    ana_dict['连损序列'] = loss_times
    ana_dict['连损次数频率统计'] = json.dumps(dict(count_loss_times))
    ana_dict['是否存在最新未结束信号'] = False
    if pd.isna(analysis_df.iloc[-1]['是否是成功的交易']):
        ana_dict['是否存在最新未结束信号'] = True
        ana_dict['信号日'] = analysis_df.iloc[-1]['信号日']
    # df_e = df[(df['胜率'] > 50)]
    ana_df = pd.DataFrame(ana_dict, index=[0])
    return ana_df


def keep_loss_times(list_data):
    """
        连损次数
    """
    max_time = 0  # 已知最大连续出现次数初始为0
    cur_time = 1  # 记录当前元素是第几次连续出现
    pre_element = None  # 记录上一个元素是什么

    for i in list_data:
        if i == pre_element:  # 如果当前元素和上一个元素相同,连续出现次数+1,并更新最大值
            cur_time += 1
            max_time = max((cur_time, max_time))
        else:  # 不同则刷新计数器
            pre_element = i
            cur_time = 1
    return cur_time, max_time


def analysis_keep_loss_times(list_data):
    """
        连损次数
    """
    max_time = 0  # 已知最大连续出现次数初始为0
    cur_time = 1  # 记录当前元素是第几次连续出现
    pre_element = None  # 记录上一个元素是什么
    loss_times = list()

    for i in list_data:
        if i == pre_element:  # 如果当前元素和上一个元素相同,连续出现次数+1,并更新最大值
            cur_time += 1
            max_time = max((cur_time, max_time))
        else:  # 不同则刷新计数器
            loss_times.append(cur_time)
            pre_element = i
            cur_time = 1

    return cur_time, max_time, loss_times, Counter(loss_times)
